{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d2bd9e7a",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dfe6983",
   "metadata": {},
   "source": [
    "# CLIP\n",
    "\n",
    "[CLIP](https://github.com/openai/CLIP) (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs by OpenAI. \n",
    "\n",
    "It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.\n",
    "\n",
    "\n",
    "![](https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f15a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the code below if you use google colab:\n",
    "#%pip install git+https://github.com/openai/CLIP.git\n",
    "#%mkdir data\n",
    "#%cd data\n",
    "#!wget https://raw.githubusercontent.com/dataflowr/notebooks/master/Module19/data/cat.jpg\n",
    "#!wget https://raw.githubusercontent.com/dataflowr/notebooks/master/Module19/data/dog.png\n",
    "#!wget https://raw.githubusercontent.com/dataflowr/notebooks/master/Module19/data/caltech101_full.json\n",
    "#%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04f6ebdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import clip\n",
    "from PIL import Image\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a68e8952",
   "metadata": {},
   "outputs": [],
   "source": [
    "dog_image = Image.open(\"data/dog.png\")\n",
    "cat_image = Image.open(\"data/cat.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ced50e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dog_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02877e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae1b2570",
   "metadata": {},
   "source": [
    "# First use of CLIP\n",
    "\n",
    "Use the [code snippets](https://github.com/openai/CLIP#usage) in order to get the good labels for the 2 images above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "377cd111",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b331385b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a3542bcc",
   "metadata": {},
   "source": [
    "Note that in the code provided, the features for the text and for the image are not used. Check that the probabilities can be recovered from these features directly."
   ]
  },
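  {
   "cell_type": "markdown",
   "id": "5e1c7a10",
   "metadata": {},
   "source": [
    "A rough sketch of this check, under stated assumptions: in CLIP's forward pass both sets of features are L2-normalized, their cosine similarities are multiplied by a learned scale (`model.logit_scale.exp()`), and the probabilities are the softmax of the resulting logits. The example below uses random arrays as stand-ins for the outputs of `model.encode_image` and `model.encode_text`, and a scale of 100.0 that is only a placeholder; read the actual value from the loaded model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Stand-ins for model.encode_image(image) and model.encode_text(text):\n",
    "# 1 image and 2 candidate captions, 512-dimensional features (hypothetical sizes).\n",
    "image_features = rng.normal(size=(1, 512))\n",
    "text_features = rng.normal(size=(2, 512))\n",
    "\n",
    "# L2-normalize so that the dot product is a cosine similarity.\n",
    "image_features /= np.linalg.norm(image_features, axis=-1, keepdims=True)\n",
    "text_features /= np.linalg.norm(text_features, axis=-1, keepdims=True)\n",
    "\n",
    "logit_scale = 100.0  # assumption: placeholder for model.logit_scale.exp()\n",
    "logits_per_image = logit_scale * image_features @ text_features.T\n",
    "\n",
    "# Numerically stable softmax over the captions.\n",
    "shifted = logits_per_image - logits_per_image.max(axis=-1, keepdims=True)\n",
    "probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)\n",
    "print(probs, probs.sum(axis=-1))\n",
    "```\n",
    "\n",
    "With real features, `probs` should agree with `logits_per_image.softmax(dim=-1)` up to floating-point precision."
   ]
  },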
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0c020f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "db48900a",
   "metadata": {},
   "source": [
    "# Building a classifier from CLIP\n",
    "\n",
    "Check that the classifier below is working for the images above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84ba03ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classifier_CLIP:\n",
    "    def __init__(self, labels):\n",
    "        self.labels = labels\n",
    "        self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "        self.model, self.preprocess = clip.load(\"ViT-B/32\", device=self.device)\n",
    "        self.text = clip.tokenize(labels).to(self.device)\n",
    "        \n",
    "    def classify(self, image_pil, verbose=False):\n",
    "        image = self.preprocess(image_pil).unsqueeze(0).to(device)\n",
    "        with torch.no_grad():\n",
    "            logits_per_image, logits_per_text = self.model(image, self.text)\n",
    "            probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
    "            if verbose:\n",
    "                print('predicted class: ', self.labels[np.argmax(probs)])\n",
    "        return np.argmax(probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f0eae4c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b98f1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#classifier.classify(dog_image, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7362cd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "#classifier.classify(cat_image, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57f0cdc2",
   "metadata": {},
   "source": [
    "# Testing the classifier on Caltech 101 \n",
    "\n",
    "Now we want to see what are the performances of this classifier on the [Caltech 101](https://data.caltech.edu/records/mzrjq-6wc02) dataset.\n",
    "\n",
    "You first need to download the dataset with [torchvision](https://pytorch.org/vision/stable/generated/torchvision.datasets.Caltech101.html#torchvision.datasets.Caltech101)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528efc2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "#caltech_data = torchvision.datasets.Caltech101('data/', download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bec2d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "caltech_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4122273",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 4578\n",
    "caltech_data[k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5f01cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "caltech_data[k][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656e46b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "caltech_data.categories[caltech_data[k][1]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f25f74a2",
   "metadata": {},
   "source": [
    "Now, you need to add methods to the `Classifier_CLIP` class. You can see on this [nice blogpost of Sean Osier](https://www.seanosier.com/2021/03/20/python-add-method-existing-class/) how to do it.\n",
    "\n",
    "First make a method to create the texts corresponding to the labels and tokenize them. Once this is done check what is the prediction made by the classifier on the image above. Is it right?"
   ]
  },
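  {
   "cell_type": "markdown",
   "id": "5e1c7a11",
   "metadata": {},
   "source": [
    "As a minimal sketch of the pattern, the example below attaches a function to an existing class and uses it to turn category names into captions. `ToyClassifier`, `make_texts` and the template `a photo of a {}` are assumptions made for illustration; in `Classifier_CLIP` the resulting captions would then go through `clip.tokenize(...).to(self.device)`.\n",
    "\n",
    "```python\n",
    "class ToyClassifier:  # hypothetical stand-in for Classifier_CLIP\n",
    "    def __init__(self, labels):\n",
    "        self.labels = labels\n",
    "\n",
    "def make_texts(self, template='a photo of a {}'):\n",
    "    # Turn raw category names such as 'wild_cat' into short captions.\n",
    "    self.texts = [template.format(label.replace('_', ' ')) for label in self.labels]\n",
    "    return self.texts\n",
    "\n",
    "# Assigning the function to the class makes it a method of every instance.\n",
    "ToyClassifier.make_texts = make_texts\n",
    "\n",
    "print(ToyClassifier(['dragonfly', 'wild_cat']).make_texts())\n",
    "```"
   ]
  },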
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee60b982",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "64b77546",
   "metadata": {},
   "source": [
    "Now, add two methods: `predict` will take a batch of images and compute the corresponding probabilities and predictions and `test` will take as input a dataloader and use predict to compute the accuracy of the classifier on the dataset.\n",
    "\n",
    "Hint: to create the dataloader you can use `from more_itertools import chunked`"
   ]
  },
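  {
   "cell_type": "markdown",
   "id": "5e1c7a12",
   "metadata": {},
   "source": [
    "One possible shape for the evaluation loop is sketched below. It only illustrates the batching and the accuracy count: the dataset is a toy list of (image, label) pairs with integers in place of images, and this `predict` is a hypothetical stand-in for a batched CLIP forward pass.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from more_itertools import chunked\n",
    "\n",
    "# Toy dataset of (image, label) pairs; integers stand in for PIL images.\n",
    "dataset = [(i, i % 3) for i in range(10)]\n",
    "\n",
    "def predict(images):\n",
    "    # Hypothetical stand-in for the model: one predicted class index per image.\n",
    "    return np.array([image % 3 for image in images])\n",
    "\n",
    "def test(loader):\n",
    "    correct, total = 0, 0\n",
    "    for batch in loader:\n",
    "        # Split the list of (image, label) pairs into two tuples.\n",
    "        images, labels = zip(*batch)\n",
    "        correct += int((predict(images) == np.array(labels)).sum())\n",
    "        total += len(labels)\n",
    "    return correct / total\n",
    "\n",
    "# chunked yields lists of at most 4 consecutive (image, label) pairs.\n",
    "print(test(chunked(dataset, 4)))\n",
    "```"
   ]
  },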
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80c9dbd1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14be629c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "30f939ed",
   "metadata": {},
   "source": [
    "# Better performances with GPT!\n",
    "\n",
    "Using the idea of [Visual Classification via Description from Large Language Models](https://github.com/sachit-menon/classify_by_description_release/tree/master#visual-classification-via-description-from-large-language-models) by Sachit Menon, Carl Vondrick (ICLR 2023), try to get better performances!\n",
    "![](https://raw.githubusercontent.com/sachit-menon/classify_by_description_release/master/figs/latent-points.png)\n",
    "\n",
    "If you do not want to do prompt engineering, there are descriptors provided in the file `caltech101_full.json`"
   ]
  },
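  {
   "cell_type": "markdown",
   "id": "5e1c7a13",
   "metadata": {},
   "source": [
    "The scoring idea can be sketched as follows: each class is scored by the average similarity between the image and that class's descriptors, and the class with the highest score is predicted. Everything concrete below is an assumption made for illustration: the `{category: [descriptor, ...]}` layout, the dragonfly descriptors, the prompt wording, and the random arrays that stand in for CLIP features.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Hypothetical descriptors in an assumed {category: [descriptor, ...]} layout.\n",
    "descriptors = {\n",
    "    'dragonfly': ['long, thin body', 'two pairs of wings'],\n",
    "    'television': ['a large, rectangular screen', 'a remote control', 'a power cord'],\n",
    "}\n",
    "\n",
    "def unit(x):\n",
    "    # L2-normalize along the last axis.\n",
    "    return x / np.linalg.norm(x, axis=-1, keepdims=True)\n",
    "\n",
    "# Random stand-in for the normalized CLIP feature of one image.\n",
    "image_feature = unit(rng.normal(size=512))\n",
    "scores = {}\n",
    "for category, descs in descriptors.items():\n",
    "    prompts = [f'{category}, which has {d}' for d in descs]  # assumed prompt wording\n",
    "    # Random stand-in for the normalized text features of the prompts.\n",
    "    text_features = unit(rng.normal(size=(len(prompts), 512)))\n",
    "    # Class score = mean similarity between the image and the class descriptors.\n",
    "    scores[category] = float((text_features @ image_feature).mean())\n",
    "\n",
    "prediction = max(scores, key=scores.get)\n",
    "print(scores, prediction)\n",
    "```\n",
    "\n",
    "Averaging over descriptors keeps classes with many descriptors from being favored over classes with few."
   ]
  },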
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e8ef32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def load_descriptors(self, filename):\n",
    "    # Opening JSON file\n",
    "    f = open(filename)\n",
    "    self.descriptors = json.load(f)\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "625b92b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "Classifier_CLIP.load_descriptors = load_descriptors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87a16b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_caltech.load_descriptors('data/caltech101_full.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43dd9472",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_caltech.descriptors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f23a3a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3b779b8a",
   "metadata": {},
   "source": [
    "# Prompt engineering for image classification\n",
    "\n",
    "Try to get better descriptors!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa322b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"OPENAI_API_KEY\"] = # your key here\n",
    "\n",
    "def stringtolist(description):\n",
    "    return [descriptor[2:] for descriptor in description.split('\\n') if (descriptor != '') and (descriptor.startswith('- '))]"
   ]
  },
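  {
   "cell_type": "markdown",
   "id": "5e1c7a14",
   "metadata": {},
   "source": [
    "To see what this parsing step does, here is a small self-contained sketch that restates the function and applies it to a made-up reply (the `reply` string is hypothetical, not an actual API response). Lines that do not start with `- `, such as the introductory sentence, are dropped.\n",
    "\n",
    "```python\n",
    "def stringtolist(description):\n",
    "    # Keep only the lines that start with '- ' and strip that bullet prefix.\n",
    "    return [line[2:] for line in description.split('\\n') if line.startswith('- ')]\n",
    "\n",
    "# Hypothetical reply in the bulleted format requested by the prompt.\n",
    "reply = 'There are several useful visual features:\\n- long tail\\n- large eyes\\n'\n",
    "print(stringtolist(reply))\n",
    "```"
   ]
  },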
  {
   "cell_type": "markdown",
   "id": "8237b855",
   "metadata": {},
   "source": [
    "I am using the same [prompts as in the original paper](https://github.com/sachit-menon/classify_by_description_release/blob/master/generate_descriptors.py) adapted to the new API. It could be improved..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7737b5e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "client = OpenAI()\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "  model=\"gpt-3.5-turbo\",\n",
    "  messages=[\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are useful visual features for distinguishing a lemur in a photo?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"\"\"There are several useful visual features to tell there is a lemur in a photo:\n",
    "- four-limbed primate\n",
    "- black, grey, white, brown, or red-brown\n",
    "- wet and hairless nose with curved nostrils\n",
    "- long tail\n",
    "- large eyes\n",
    "- furry bodies\n",
    "- clawed hands and feet\"\"\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are useful visual features for distinguishing a television in a photo?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"\"\"There are several useful visual features to tell there is a television in a photo:\n",
    "- electronic device\n",
    "- black or grey\n",
    "- a large, rectangular screen\n",
    "- a stand or mount to support the screen\n",
    "- one or more speakers\n",
    "- a power cord\n",
    "- input ports for connecting to other devices\n",
    "- a remote control\"\"\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are useful visual features for distinguishing a dragonfly in a photo? Provide an answer following the above pattern, give only the list of visual features.\"}\n",
    "  ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b287d961",
   "metadata": {},
   "outputs": [],
   "source": [
    "response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb7c425",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_prompt(category_name: str):\n",
    "    # you can replace the examples with whatever you want; these were random and worked, could be improved\n",
    "    return [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are useful visual features for distinguishing a lemur in a photo?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"\"\"There are several useful visual features to tell there is a lemur in a photo:\n",
    "- four-limbed primate\n",
    "- black, grey, white, brown, or red-brown\n",
    "- wet and hairless nose with curved nostrils\n",
    "- long tail\n",
    "- large eyes\n",
    "- furry bodies\n",
    "- clawed hands and feet\"\"\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are useful visual features for distinguishing a television in a photo?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"\"\"There are several useful visual features to tell there is a television in a photo:\n",
    "- electronic device\n",
    "- black or grey\n",
    "- a large, rectangular screen\n",
    "- a stand or mount to support the screen\n",
    "- one or more speakers\n",
    "- a power cord\n",
    "- input ports for connecting to other devices\n",
    "- a remote control\"\"\"},\n",
    "    {\"role\": \"user\", \"content\": f\"What are useful visual features for distinguishing a {category_name} in a photo? Provide an answer following the above pattern, give only the list of visual features.\"}\n",
    "  ]\n",
    "\n",
    "def obtain_descriptors_and_save(filename, class_list):\n",
    "    responses = {}\n",
    "    descriptors = {}\n",
    "    prompts = [generate_prompt(category.replace('_', ' ')) for category in class_list]\n",
    "    client = OpenAI()\n",
    "\n",
    "    responses = [client.chat.completions.create(\n",
    "      model=\"gpt-3.5-turbo\",\n",
    "      messages= prompt\n",
    "    ) for prompt in prompts]\n",
    "    \n",
    "    \n",
    "    response_texts = [resp.choices[0].message.content for resp in responses]\n",
    "    descriptors_list = [stringtolist(response_text) for response_text in response_texts]\n",
    "    descriptors = {cat: descr for cat, descr in zip(class_list, descriptors_list)}\n",
    "\n",
    "    # save descriptors to json file\n",
    "    if not filename.endswith('.json'):\n",
    "        filename += '.json'\n",
    "    with open(filename, 'w') as fp:\n",
    "        json.dump(descriptors, fp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
